
Chapter 14: Implementing a RAG Knowledge Base System

Learning Objectives

  • Master the design and implementation workflow of a complete RAG knowledge base system
  • Learn to build a high-quality document processing and indexing pipeline
  • Understand how the components of a RAG system work together
  • Learn how to evaluate and optimize the performance of a RAG knowledge base system

Overview of RAG Knowledge Base Systems

A RAG knowledge base system combines retrieval augmentation with knowledge management, helping users efficiently acquire, understand, and apply large bodies of specialized knowledge.

See the accompanying course video for additional context.

Core Components and Workflow

A complete RAG knowledge base system typically includes the following core components (a minimal wiring sketch follows the list):

  1. Document processing and indexing system: collects, preprocesses, splits, and indexes documents
  2. Retrieval system: supports multiple retrieval strategies to find the most relevant information
  3. Generation system: uses the retrieved results as context to generate high-quality answers
  4. User interface: provides an intuitive way to interact, supporting Q&A and feedback
  5. Evaluation and monitoring system: continuously measures performance and collects feedback to improve the system
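
Taken together, these components form a pipeline from raw documents to retrieved context. Below is a minimal wiring sketch using the classes implemented later in this chapter (DocumentProcessor, MetadataEnhancer, HybridIndexBuilder, AdvancedRetrievalSystem); treat it as an outline of how the pieces fit together rather than a finished implementation, since generation and evaluation are covered in the next section.

python
# Minimal wiring sketch; the classes referenced here are built step by step
# in the sections below, so this block only shows how they connect.
class KnowledgeBasePipeline:
    def __init__(self, processor, enhancer, index_builder, llm, embeddings):
        self.processor = processor          # loading, cleaning, splitting
        self.enhancer = enhancer            # metadata extraction (topics, summaries)
        self.index_builder = index_builder  # hybrid (vector + BM25) index construction
        self.llm = llm
        self.embeddings = embeddings
        self.retrieval_system = None

    def ingest(self, source, source_type="directory"):
        """Build the knowledge base from raw sources."""
        chunks = self.processor.process_documents(source, source_type=source_type)
        enhanced = self.enhancer.enhance_batch(chunks)
        indices = self.index_builder.build_hybrid_index(enhanced)
        self.retrieval_system = AdvancedRetrievalSystem(indices, self.llm, self.embeddings)

    def query(self, question):
        """Retrieve supporting chunks for a question; answer generation comes later."""
        return self.retrieval_system.retrieve(question, strategy="hybrid")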

Implementing the Document Processing and Indexing System

1. Document Collection and Preprocessing

First, we build a flexible document loading and preprocessing pipeline:

python
from langchain.document_loaders import (
    PyPDFLoader,
    TextLoader,
    CSVLoader,
    UnstructuredMarkdownLoader,
    WebBaseLoader
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
import os
import re

class DocumentProcessor:
    def __init__(self):
        # Supported file extensions and their corresponding loaders
        self.loaders = {
            ".pdf": PyPDFLoader,
            ".txt": TextLoader,
            ".csv": CSVLoader,
            ".md": UnstructuredMarkdownLoader
        }
        # Initialize the text splitter
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=1000,
            chunk_overlap=200
        )
    
    def load_directory(self, directory_path):
        """Load all documents in a directory tree."""
        documents = []
        for root, _, files in os.walk(directory_path):
            for file in files:
                file_path = os.path.join(root, file)
                try:
                    file_documents = self.load_file(file_path)
                    documents.extend(file_documents)
                except Exception as e:
                    print(f"Error loading {file_path}: {e}")
        return documents
    
    def load_file(self, file_path):
        """Load a single file."""
        ext = os.path.splitext(file_path)[1].lower()
        if ext in self.loaders:
            loader = self.loaders[ext](file_path)
            return loader.load()
        else:
            raise ValueError(f"Unsupported file type: {ext}")
    
    def load_web(self, urls):
        """Load web page content."""
        loader = WebBaseLoader(urls)
        return loader.load()
    
    def preprocess_text(self, text):
        """Basic text cleanup."""
        # Collapse redundant whitespace
        text = re.sub(r'\s+', ' ', text).strip()
        # Remove special characters (word characters, whitespace, and common punctuation are kept)
        text = re.sub(r'[^\w\s.,;:!?()[\]{}"\'-]', '', text)
        return text
    
    def split_documents(self, documents):
        """Split documents into chunks."""
        return self.text_splitter.split_documents(documents)
    
    def process_documents(self, source, source_type="directory"):
        """Full processing pipeline: load, preprocess, split."""
        # Load documents
        if source_type == "directory":
            documents = self.load_directory(source)
        elif source_type == "file":
            documents = self.load_file(source)
        elif source_type == "web":
            documents = self.load_web(source)
        else:
            raise ValueError(f"Unsupported source type: {source_type}")
        
        # Preprocess and split
        for doc in documents:
            doc.page_content = self.preprocess_text(doc.page_content)
        
        return self.split_documents(documents)

# Use the document processor
processor = DocumentProcessor()
documents = processor.process_documents("./knowledge_base", source_type="directory")
print(f"Processed {len(documents)} document chunks")

2. Advanced Document Splitting Strategies

For different types of documents, we can use smarter splitting strategies:

python
from langchain.text_splitter import (
    RecursiveCharacterTextSplitter,
    MarkdownTextSplitter,
    PythonCodeTextSplitter,
    Language
)
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.llms import DeepSeek

class SmartDocumentSplitter:
    def __init__(self, llm):
        self.llm = llm
        
        # Initialize splitters for different content types
        self.splitters = {
            "default": RecursiveCharacterTextSplitter(
                chunk_size=1000,
                chunk_overlap=200
            ),
            "markdown": MarkdownTextSplitter(
                chunk_size=1000,
                chunk_overlap=200
            ),
            "python": PythonCodeTextSplitter(
                chunk_size=1000,
                chunk_overlap=200
            ),
            # Language-aware recursive splitter for HTML content
            "html": RecursiveCharacterTextSplitter.from_language(
                language=Language.HTML,
                chunk_size=1000,
                chunk_overlap=200
            )
        }
        
        # Create the semantic splitting chain
        semantic_split_template = """
        Analyze the following text and identify the best semantic split points, dividing it into
        coherent, relatively self-contained passages. Each passage should preserve its internal
        semantic integrity, ideally splitting at natural topic transitions.
        
        Original text:
        {text}
        
        Provide the split positions (character indices), one index per line:
        """
        
        self.semantic_splitter = LLMChain(
            llm=llm,
            prompt=PromptTemplate(
                input_variables=["text"],
                template=semantic_split_template
            )
        )
    
    def split_by_type(self, document, doc_type="default"):
        """Choose a splitter based on the document type."""
        splitter = self.splitters.get(doc_type, self.splitters["default"])
        return splitter.split_documents([document])
    
    def split_by_semantic(self, document, max_chunk_size=1000):
        """Split intelligently based on semantic content."""
        text = document.page_content
        
        # Return short texts unchanged
        if len(text) <= max_chunk_size:
            return [document]
        
        # Ask the LLM for semantic split points
        try:
            split_response = self.semantic_splitter.run(text=text[:min(len(text), 4000)])
            split_indices = [int(idx.strip()) for idx in split_response.split("\n") if idx.strip().isdigit()]
            
            # Keep only valid split points
            split_indices = [idx for idx in split_indices if 0 < idx < len(text)]
            split_indices = sorted(split_indices)
            
            # Fall back to the default splitter if no valid split points were returned
            if not split_indices:
                return self.split_by_type(document)
            
            # Create document chunks from the split points
            chunks = []
            start_idx = 0
            
            for idx in split_indices:
                chunk_text = text[start_idx:idx].strip()
                if chunk_text:
                    chunk_doc = document.copy()
                    chunk_doc.page_content = chunk_text
                    chunks.append(chunk_doc)
                start_idx = idx
            
            # Add the final chunk
            if start_idx < len(text):
                chunk_text = text[start_idx:].strip()
                if chunk_text:
                    chunk_doc = document.copy()
                    chunk_doc.page_content = chunk_text
                    chunks.append(chunk_doc)
            
            return chunks
        except Exception as e:
            print(f"Error in semantic splitting: {e}")
            return self.split_by_type(document)
    
    def split_mixed_content(self, document):
        """Handle documents with mixed content."""
        # Detect the document type
        content = document.page_content.lower()
        
        if "```python" in content or "```py" in content:
            # Contains Python code
            return self.split_by_type(document, "python")
        elif content.count("#") > 5 or "---" in content:
            # Probably Markdown
            return self.split_by_type(document, "markdown")
        elif "<html" in content or "<div" in content or "<p>" in content:
            # Probably HTML
            return self.split_by_type(document, "html")
        else:
            # Try semantic splitting
            return self.split_by_semantic(document)

# Use the smart splitter (assumes a DeepSeek LLM integration compatible with the LangChain LLM interface)
llm = DeepSeek(api_key="your-api-key")
smart_splitter = SmartDocumentSplitter(llm)

# Process a document
document = documents[0]  # assumes the documents list built above
chunks = smart_splitter.split_mixed_content(document)
print(f"Split into {len(chunks)} chunks")

3. Metadata Extraction and Enrichment

Enriching documents with metadata can significantly improve retrieval quality:

python
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
import hashlib
import datetime

class MetadataEnhancer:
    def __init__(self, llm):
        self.llm = llm
        
        # Create the topic extraction chain
        topic_extract_template = """
        Analyze the following text and extract 3-5 keywords or phrases that best represent its core topics.

        Text:
        {text}

        Key topics (comma-separated):
        """
        
        self.topic_extractor = LLMChain(
            llm=llm,
            prompt=PromptTemplate(
                input_variables=["text"],
                template=topic_extract_template
            )
        )
        
        # Create the summarization chain
        summary_template = """
        Write a concise summary (50 words or fewer) of the following text.

        Text:
        {text}

        Summary:
        """
        
        self.summarizer = LLMChain(
            llm=llm,
            prompt=PromptTemplate(
                input_variables=["text"],
                template=summary_template
            )
        )
    
    def extract_metadata(self, document):
        """Extract and enrich document metadata."""
        # Copy any existing metadata
        metadata = document.metadata.copy() if hasattr(document, "metadata") else {}
        
        text = document.page_content
        
        # Basic metadata
        metadata["doc_id"] = hashlib.md5(text.encode()).hexdigest()
        metadata["char_count"] = len(text)
        metadata["word_count"] = len(text.split())
        metadata["processed_date"] = datetime.datetime.now().isoformat()
        
        # Extract topics
        try:
            topics_text = self.topic_extractor.run(text=text[:min(len(text), 3000)])
            topics = [topic.strip() for topic in topics_text.split(",")]
            metadata["topics"] = topics
        except Exception as e:
            print(f"Error extracting topics: {e}")
        
        # Generate a summary
        try:
            metadata["summary"] = self.summarizer.run(text=text[:min(len(text), 3000)])
        except Exception as e:
            print(f"Error generating summary: {e}")
        
        # Detect the language (simple heuristic based on the ratio of CJK characters)
        chinese_char_ratio = len([c for c in text if '\u4e00' <= c <= '\u9fff']) / max(len(text), 1)
        metadata["language"] = "zh" if chinese_char_ratio > 0.1 else "en"
        
        # Update the document metadata
        document.metadata = metadata
        return document
    
    def enhance_batch(self, documents):
        """Enrich metadata for a batch of documents."""
        enhanced_docs = []
        for doc in documents:
            enhanced_docs.append(self.extract_metadata(doc))
        return enhanced_docs

# Use the metadata enhancer
enhancer = MetadataEnhancer(llm)
enhanced_documents = enhancer.enhance_batch(chunks)
print(f"Enhanced {len(enhanced_documents)} documents with metadata")

4. Building a Hybrid Index

Build a hybrid index that supports both vector retrieval and keyword retrieval:

python
from langchain.vectorstores import Chroma
from langchain.embeddings import DeepSeekEmbeddings
from langchain.retrievers import BM25Retriever
import pickle
import os

class HybridIndexBuilder:
    def __init__(self, embeddings, persist_directory="./hybrid_index"):
        self.embeddings = embeddings
        self.persist_directory = persist_directory
        os.makedirs(persist_directory, exist_ok=True)
    
    def build_vector_index(self, documents):
        """Build the vector index."""
        vector_db = Chroma.from_documents(
            documents=documents,
            embedding=self.embeddings,
            persist_directory=os.path.join(self.persist_directory, "vector_db")
        )
        vector_db.persist()
        return vector_db
    
    def build_keyword_index(self, documents):
        """Build the keyword index."""
        bm25_retriever = BM25Retriever.from_documents(documents)
        # Save the BM25 retriever
        with open(os.path.join(self.persist_directory, "bm25.pkl"), "wb") as f:
            pickle.dump(bm25_retriever, f)
        return bm25_retriever
    
    def build_hybrid_index(self, documents):
        """Build the hybrid index."""
        print("Building vector index...")
        vector_db = self.build_vector_index(documents)
        
        print("Building keyword index...")
        bm25_retriever = self.build_keyword_index(documents)
        
        # Save the document ID mapping
        doc_ids = {doc.metadata.get("doc_id", i): i for i, doc in enumerate(documents)}
        with open(os.path.join(self.persist_directory, "doc_ids.pkl"), "wb") as f:
            pickle.dump(doc_ids, f)
        
        print("Hybrid index built successfully")
        return {
            "vector_db": vector_db,
            "bm25_retriever": bm25_retriever,
            "doc_ids": doc_ids
        }
    
    def load_hybrid_index(self):
        """Load an existing hybrid index."""
        # Load the vector database
        vector_db = Chroma(
            persist_directory=os.path.join(self.persist_directory, "vector_db"),
            embedding_function=self.embeddings
        )
        
        # Load the BM25 retriever
        with open(os.path.join(self.persist_directory, "bm25.pkl"), "rb") as f:
            bm25_retriever = pickle.load(f)
        
        # Load the document ID mapping
        with open(os.path.join(self.persist_directory, "doc_ids.pkl"), "rb") as f:
            doc_ids = pickle.load(f)
        
        return {
            "vector_db": vector_db,
            "bm25_retriever": bm25_retriever,
            "doc_ids": doc_ids
        }

# Initialize the embedding model (assumes a DeepSeek embeddings integration compatible with the LangChain Embeddings interface)
embeddings = DeepSeekEmbeddings(api_key="your-api-key")

# Build the hybrid index
index_builder = HybridIndexBuilder(embeddings)
indices = index_builder.build_hybrid_index(enhanced_documents)

# Load an existing index
# loaded_indices = index_builder.load_hybrid_index()
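
Before wiring up the full retrieval system, it can help to sanity-check the two sub-indexes separately. A minimal check, assuming the indices dictionary built above and a hypothetical test query:

python
# Quick sanity check on the freshly built hybrid index (assumes `indices` from above)
sample_query = "vector database"  # hypothetical test query

# Vector side: Chroma returns the k most similar chunks by embedding distance
vector_hits = indices["vector_db"].similarity_search(sample_query, k=3)
print(f"Vector index returned {len(vector_hits)} chunks")

# Keyword side: BM25 ranks chunks by lexical overlap with the query
keyword_hits = indices["bm25_retriever"].get_relevant_documents(sample_query)
print(f"BM25 index returned {len(keyword_hits)} chunks")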

Implementing a High-Performance Retrieval System

Building on the hybrid retrieval strategies from the previous section, we now design a complete high-performance retrieval system:

python
from langchain.retrievers import EnsembleRetriever
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
import numpy as np

class AdvancedRetrievalSystem:
    def __init__(self, hybrid_indices, llm, embeddings):
        self.vector_db = hybrid_indices["vector_db"]
        self.bm25_retriever = hybrid_indices["bm25_retriever"]
        self.doc_ids = hybrid_indices["doc_ids"]
        self.llm = llm
        self.embeddings = embeddings
        
        # Set up the retrievers
        self.setup_retrievers()
        
        # Set up the query optimizer
        self.setup_query_optimizer()
    
    def setup_retrievers(self):
        """Set up the individual retrievers."""
        # Vector retriever
        self.vector_retriever = self.vector_db.as_retriever(
            search_kwargs={"k": 5}
        )
        
        # Ensemble retriever combining keyword and vector results
        self.ensemble_retriever = EnsembleRetriever(
            retrievers=[self.bm25_retriever, self.vector_retriever],
            weights=[0.3, 0.7]
        )
    
    def setup_query_optimizer(self):
        """Set up the query optimizer."""
        query_optimizer_template = """
        Analyze the following user query and rewrite it into a form better suited to a retrieval system.
        Add key terms, resolve ambiguity, and make the intent of the query explicit.

        User query: {query}

        Rewritten query:
        """
        
        self.query_optimizer = LLMChain(
            llm=self.llm,
            prompt=PromptTemplate(
                input_variables=["query"],
                template=query_optimizer_template
            )
        )
    
    def filter_similar_documents(self, documents, threshold=0.95):
        """Filter out near-duplicate documents."""
        if not documents:
            return []
        
        # Embed all documents in a single batch call
        embeddings_list = self.embeddings.embed_documents(
            [doc.page_content for doc in documents]
        )
        
        # Compute the similarity matrix
        similarity_matrix = np.zeros((len(documents), len(documents)))
        for i in range(len(documents)):
            for j in range(i, len(documents)):
                if i == j:
                    similarity_matrix[i][j] = 1.0
                else:
                    # Cosine similarity
                    similarity = np.dot(embeddings_list[i], embeddings_list[j]) / (
                        np.linalg.norm(embeddings_list[i]) * np.linalg.norm(embeddings_list[j])
                    )
                    similarity_matrix[i][j] = similarity
                    similarity_matrix[j][i] = similarity
        
        # Greedily select mutually dissimilar documents
        selected_indices = []
        for i in range(len(documents)):
            # Check whether this document is too similar to any already-selected one
            is_similar = False
            for selected_idx in selected_indices:
                if similarity_matrix[i][selected_idx] > threshold:
                    is_similar = True
                    break
            
            if not is_similar:
                selected_indices.append(i)
        
        # Return the filtered documents
        return [documents[i] for i in selected_indices]
    
    def retrieve(self, query, strategy="hybrid", optimize_query=True, filter_similar=True):
        """Run retrieval."""
        # Query optimization
        if optimize_query:
            try:
                optimized_query = self.query_optimizer.run(query=query)
                query = optimized_query
            except Exception as e:
                print(f"Error optimizing query: {e}")
        
        # Choose the retrieval strategy
        if strategy == "vector":
            results = self.vector_retriever.get_relevant_documents(query)
        elif strategy == "keyword":
            results = self.bm25_retriever.get_relevant_documents(query)
        elif strategy == "hybrid":
            results = self.ensemble_retriever.get_relevant_documents(query)
        else:
            raise ValueError(f"Unknown retrieval strategy: {strategy}")
        
        # Filter near-duplicate documents
        if filter_similar and results:
            results = self.filter_similar_documents(results)
        
        return results

# Create the retrieval system
retrieval_system = AdvancedRetrievalSystem(indices, llm, embeddings)

# Run a retrieval
results = retrieval_system.retrieve(
    "Applications of deep learning in natural language processing",
    strategy="hybrid",
    optimize_query=True,
    filter_similar=True
)

print(f"Retrieved {len(results)} documents")
for i, doc in enumerate(results):
    print(f"Document {i+1}:")
    print(f"Summary: {doc.metadata.get('summary', 'N/A')}")
    print(f"Topics: {doc.metadata.get('topics', 'N/A')}")
    print(f"Content: {doc.page_content[:100]}...")
    print("-" * 50)

In the next section, we will continue with the generation component of the RAG system and methods for evaluating it.

Discussion Questions

  1. When building a RAG knowledge base, how do you determine the optimal chunk size for document splitting? How does this decision affect retrieval quality?

  2. What special processing strategies are appropriate for technical documents that contain many tables, figures, and code blocks?

  3. How does metadata extraction affect the retrieval performance of a RAG system? Which types of metadata are especially important for documents in different domains?

  4. In practice, how do you balance the time cost of index construction against retrieval performance? What optimization techniques are available?